{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "93894144-fad1-446b-83e8-70402a1a87ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from transformers import pipeline\n",
    "import pycountry\n",
    "from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8a9515bb-3f25-4e55-9050-86775ad58c91",
   "metadata": {},
   "outputs": [],
   "source": [
    "df=pd.read_csv(r'C:\\Users\\Yasaman\\Arab_spring_scholarly_attention\\Validating\\final_annotated.csv')\n",
    "df['Text']=df['Title']+' '+df['Abstract']\n",
    "df['union_annotation'] = df['union_annotation'].apply(lambda x: eval(x) if isinstance(x, str) else x)\n",
    "df['intersection_annotation'] = df['intersection_annotation'].apply(lambda x: eval(x) if isinstance(x, str) else x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "338a7890-3c53-4423-8e2a-0059619423da",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at dslim/bert-large-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "100%|██████████| 1000/1000 [30:55<00:00,  1.86s/it]\n"
     ]
    }
   ],
   "source": [
    "# Load tokenizer and model\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"dslim/bert-large-NER\")\n",
    "model = AutoModelForTokenClassification.from_pretrained(\"dslim/bert-large-NER\")\n",
    "\n",
    "nlp = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
    "\n",
    "# Function to extract locations and filter by country names\n",
    "def extract_and_filter_locations(text):\n",
    "    ner_results = nlp(text)\n",
    "    locations = [\n",
    "        entity[\"word\"]\n",
    "        for entity in ner_results\n",
    "         if ( entity[\"entity\"] == \"B-LOC\" or entity[\"entity\"] == \"I-LOC\")\n",
    "    ]\n",
    "\n",
    "    final=coco.convert(list(set(locations)), to='ISO3') \n",
    "    if type(final)==str:\n",
    "        final=[final]\n",
    "\n",
    "    final=[x.lower() for x in final if x!='not found']\n",
    "   \n",
    "    return final\n",
    "\n",
    "# Apply the function to the DataFrame with tqdm\n",
    "tqdm.pandas()  # Initialize tqdm for pandas\n",
    "import logging\n",
    "import country_converter as coco\n",
    "coco_logger = coco.logging.getLogger()\n",
    "coco_logger.setLevel(logging.CRITICAL)\n",
    "df[\"locations\"] = df[\"Text\"].progress_apply(extract_and_filter_locations)\n",
    "\n",
    "# Display the result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9cbdf79e-8fc6-43a0-a9ee-0a1ace486b9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv(r'C:\\Users\\Yasaman\\Arab_spring_scholarly_attention\\Validating\\bert_large_annotated.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9fbaf1a5-473c-43e5-b41a-7cc04d604399",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "overall accuracy (union) 0.727\n",
      "overall accuracy (intersection) 0.718\n"
     ]
    }
   ],
   "source": [
    "print('overall accuracy (union)', sum(df['locations']==df['union_annotation'])/1000)\n",
    "print('overall accuracy (intersection)', sum(df['locations']==df['intersection_annotation'])/1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2dea575a-e91a-4aa3-8d9e-2b66a5b3d291",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Yasaman\\AppData\\Local\\Temp\\ipykernel_28352\\2343847047.py:1: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
      "  sample_accuracies = df.groupby(\"SampleGroup\").apply(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SampleGroup</th>\n",
       "      <th>accuracy_union</th>\n",
       "      <th>accuracy_intersection</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>with_mention_arab</td>\n",
       "      <td>0.620000</td>\n",
       "      <td>0.580000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>with_mention</td>\n",
       "      <td>0.632143</td>\n",
       "      <td>0.582143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>field_20</td>\n",
       "      <td>0.907692</td>\n",
       "      <td>0.909615</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         SampleGroup  accuracy_union  accuracy_intersection\n",
       "0  with_mention_arab        0.620000               0.580000\n",
       "1       with_mention        0.632143               0.582143\n",
       "2           field_20        0.907692               0.909615"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_accuracies = df.groupby(\"SampleGroup\").apply(\n",
    "    lambda x: pd.Series({\n",
    "        \"accuracy_union\": (x.apply(lambda row: set(row['locations']) == set(row['union_annotation']), axis=1).mean()),\n",
    "        \"accuracy_intersection\": (x.apply(lambda row: set(row['locations']) == set(row['intersection_annotation']), axis=1).mean())\n",
    "    })\n",
    ").reset_index()\n",
    "\n",
    "sample_accuracies = sample_accuracies.sort_values(by='accuracy_union').reset_index(drop=True)\n",
    "sample_accuracies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3aabf449-ed59-403a-a325-e2652fe72d78",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Yasaman\\AppData\\Local\\Temp\\ipykernel_28352\\388812476.py:1: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
      "  sample_accuracies = df.groupby(\"SampleGroup\").apply(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SampleGroup</th>\n",
       "      <th>jaccard_union</th>\n",
       "      <th>jaccard_intersection</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>with_mention</td>\n",
       "      <td>0.719226</td>\n",
       "      <td>0.675119</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>with_mention_arab</td>\n",
       "      <td>0.796913</td>\n",
       "      <td>0.746531</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>field_20</td>\n",
       "      <td>0.921538</td>\n",
       "      <td>0.920256</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         SampleGroup  jaccard_union  jaccard_intersection\n",
       "0       with_mention       0.719226              0.675119\n",
       "1  with_mention_arab       0.796913              0.746531\n",
       "2           field_20       0.921538              0.920256"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_accuracies = df.groupby(\"SampleGroup\").apply(\n",
    "    lambda x: pd.Series({\n",
    "        \"jaccard_union\": (\n",
    "            x.apply(\n",
    "                lambda row: len(set(row['locations']).intersection(set(row['union_annotation']))) /\n",
    "                            len(set(row['locations']).union(set(row['union_annotation'])))\n",
    "                if len(set(row['locations']).union(set(row['union_annotation']))) > 0 else 1,\n",
    "                axis=1\n",
    "            ).mean()\n",
    "        ),\n",
    "        \"jaccard_intersection\": (\n",
    "            x.apply(\n",
    "                lambda row: len(set(row['locations']).intersection(set(row['intersection_annotation']))) /\n",
    "                            len(set(row['locations']).union(set(row['intersection_annotation'])))\n",
    "                if len(set(row['locations']).union(set(row['intersection_annotation']))) > 0 else 1,\n",
    "                axis=1\n",
    "            ).mean()\n",
    "        )\n",
    "    })\n",
    ").reset_index()\n",
    "\n",
    "sample_accuracies = sample_accuracies.sort_values(by='jaccard_union').reset_index(drop=True)\n",
    "sample_accuracies"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
